from transformers import AutoImageProcessor, ResNetModel
import torch
from datasets import load_dataset
from PIL import Image

# Extract ResNet-50 features from a single image and print the shapes of the
# intermediate tensors.
#
# Alternative input: take a sample image from the Hugging Face hub instead of
# reading a local file.
# dataset = load_dataset("huggingface/cats-image")
# image = dataset["test"]["image"][0]

image_path = "data/douyin_fake_comments_try/images/Au_001.jpg"
# Image.open() is lazy, so decode inside the `with` block and let the file
# handle be released deterministically. convert("RGB") forces the decode and
# returns an independent 3-channel image, so grayscale / palette / CMYK / RGBA
# files no longer reach the processor with an unexpected channel layout.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")
print(type(image))

# Both the preprocessing config and the weights are read from the same local
# checkpoint directory.
model_name_or_path = "../model/resnet-50"
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)
model = ResNetModel.from_pretrained(model_name_or_path)

# Preprocess the image into PyTorch tensors ("pt").
inputs = image_processor(image, return_tensors="pt")

print(inputs.keys())
# expected: dict_keys(['pixel_values'])
print(type(inputs["pixel_values"]))
# expected: <class 'torch.Tensor'>
print(inputs["pixel_values"].shape)
# expected: torch.Size([1, 3, 224, 224])

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state
pooler_output = outputs.pooler_output
print(list(last_hidden_states.shape))
# expected: [1, 2048, 7, 7]
print(list(pooler_output.shape))
# expected: [1, 2048, 1, 1]
